{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ic4_occAAiAT"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "ioaprt5q5US7"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sPrjeJhFQBmu"
      },
      "source": [
        "# Integrated gradients"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hKY4XMc9o8iB"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/interpretability/integrated_gradients\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/interpretability/integrated_gradients.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/interpretability/integrated_gradients.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/interpretability/integrated_gradients.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://tfhub.dev/google/imagenet/inception_v1/classification/4\"><img src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" />See TF Hub model</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NG17_Wp6ikKf"
      },
      "source": [
        "This tutorial demonstrates how to implement **Integrated Gradients (IG)**, an [Explainable AI](https://en.wikipedia.org/wiki/Explainable_artificial_intelligence) technique introduced in the paper [Axiomatic Attribution for Deep Networks](https://arxiv.org/abs/1703.01365). IG aims to explain the relationship between a model's predictions and its input features. It has many use cases including understanding feature importances, identifying data skew, and debugging model performance.\n",
        "\n",
        "IG has become a popular interpretability technique due to its broad applicability to any differentiable model (e.g. images, text, structured data), ease of implementation, theoretical justifications, and computational efficiency relative to alternative approaches that allows it to scale to large networks and feature spaces such as images.\n",
        "\n",
        "In this tutorial, you will walk through an implementation of IG step-by-step to understand the pixel feature importances of an image classifier. As an example, consider this [image](https://commons.wikimedia.org/wiki/File:San_Francisco_fireboat_showing_off.jpg) of a fireboat spraying jets of water. You would classify this image as a fireboat and might highlight the pixels making up the boat and water cannons as being important to your decision. Your model will also classify this image as a fireboat later on in this tutorial; however, does it highlight the same pixels as important when explaining its decision?\n",
        "\n",
        "In the images below titled \"IG Attribution Mask\" and \"Original + IG Mask Overlay\" you can see that your model instead highlights (in purple) the pixels comprising the boat's water cannons and jets of water as being more important than the boat itself to its decision. How will your model generalize to new fireboats? What about fireboats without water jets? Read on to learn more about how IG works and how to apply IG to your models to better understand the relationship between their predictions and underlying features.\n",
        "\n",
        "![Output Image 1](images/IG_fireboat.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ppydw6ZbKzM1"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cbUMIubipgg0"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GVVV4BGrABkA"
      },
      "source": [
        "### Download a pretrained image classifier from TF-Hub"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7kwwJ35xmtoK"
      },
      "source": [
        "IG can be applied to any differentiable model. In the spirit of the original paper, you will use a pre-trained version of the same model, Inception V1, which you will download from [TensorFlow Hub](https://tfhub.dev/google/imagenet/inception_v1/classification/4)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "14APZcfHolKj"
      },
      "outputs": [],
      "source": [
        "model = tf.keras.Sequential([\n",
        "    hub.KerasLayer(\n",
        "        name='inception_v1',\n",
        "        handle='https://tfhub.dev/google/imagenet/inception_v1/classification/4',\n",
        "        trainable=False),\n",
        "])\n",
        "model.build([None, 224, 224, 3])\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GjLRn2e5xFOb"
      },
      "source": [
        "From the module page, you need to keep in mind the following about Inception V1:\n",
        "\n",
        "**Inputs**: The expected input shape for the model is `(None, 224, 224, 3)`. This is a dense 4D tensor of dtype float32 and shape `(batch_size, height, width, RGB channels)` whose elements are RGB color values of pixels normalized to the range [0, 1]. The first dimension is `None` to indicate that the model can take any integer batch size.\n",
        "\n",
        "**Outputs**: A `tf.Tensor` of logits in the shape of `(batch_size, 1001)`. Each row represents the model's predicted score for each of 1,001 classes from ImageNet. For the model's top predicted class index you can use `tf.argmax(predictions, axis=-1)`. Furthermore, you can also convert the model's logit output to predicted probabilities across all classes using `tf.nn.softmax(predictions, axis=-1)` to quantify the model's uncertainty as well as explore similar predicted classes for debugging."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "huZnb_O0L9mw"
      },
      "outputs": [],
      "source": [
        "def load_imagenet_labels(file_path):\n",
        "  labels_file = tf.keras.utils.get_file('ImageNetLabels.txt', file_path)\n",
        "  with open(labels_file) as reader:\n",
        "    f = reader.read()\n",
        "    labels = f.splitlines()\n",
        "  return np.array(labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Rtrl-u7T6NEk"
      },
      "outputs": [],
      "source": [
        "imagenet_labels = load_imagenet_labels('https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "STpIJr1Z5r_u"
      },
      "source": [
        "### Load and preprocess images with `tf.image`\n",
        "\n",
        "You will illustrate IG using two images from [Wikimedia Commons](https://commons.wikimedia.org/wiki/Main_Page): a [Fireboat](https://commons.wikimedia.org/wiki/File:San_Francisco_fireboat_showing_off.jpg), and a [Giant Panda](https://commons.wikimedia.org/wiki/File:Giant_Panda_2.JPG)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YOb0Adq-rU5J"
      },
      "outputs": [],
      "source": [
        "def read_image(file_name):\n",
        "  image = tf.io.read_file(file_name)\n",
        "  image = tf.image.decode_jpeg(image, channels=3)\n",
        "  image = tf.image.convert_image_dtype(image, tf.float32)\n",
        "  image = tf.image.resize_with_pad(image, target_height=224, target_width=224)\n",
        "  return image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_khLTN75CLMJ"
      },
      "outputs": [],
      "source": [
        "img_url = {\n",
        "    'Fireboat': 'http://storage.googleapis.com/download.tensorflow.org/example_images/San_Francisco_fireboat_showing_off.jpg',\n",
        "    'Giant Panda': 'http://storage.googleapis.com/download.tensorflow.org/example_images/Giant_Panda_2.jpeg',\n",
        "}\n",
        "\n",
        "img_paths = {name: tf.keras.utils.get_file(name, url) for (name, url) in img_url.items()}\n",
        "img_name_tensors = {name: read_image(img_path) for (name, img_path) in img_paths.items()}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AYIeu8rMLN-8"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(8, 8))\n",
        "for n, (name, img_tensors) in enumerate(img_name_tensors.items()):\n",
        "  ax = plt.subplot(1, 2, n+1)\n",
        "  ax.imshow(img_tensors)\n",
        "  ax.set_title(name)\n",
        "  ax.axis('off')\n",
        "plt.tight_layout()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QvTYc8IZKbeO"
      },
      "source": [
        "### Classify images\n",
        "Let's start by classifying these images and displaying the top 3 most confident predictions. Following is a utility function to retrieve the top k predicted labels and probabilities."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5gsO7ILHZ0By"
      },
      "outputs": [],
      "source": [
        "def top_k_predictions(img, k=3):\n",
        "  image_batch = tf.expand_dims(img, 0)\n",
        "  predictions = model(image_batch)\n",
        "  probs = tf.nn.softmax(predictions, axis=-1)\n",
        "  top_probs, top_idxs = tf.math.top_k(input=probs, k=k)\n",
        "  top_labels = imagenet_labels[tuple(top_idxs)]\n",
        "  return top_labels, top_probs[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l80a8N2vcIP-"
      },
      "outputs": [],
      "source": [
        "for (name, img_tensor) in img_name_tensors.items():\n",
        "  plt.imshow(img_tensor)\n",
        "  plt.title(name, fontweight='bold')\n",
        "  plt.axis('off')\n",
        "  plt.show()\n",
        "\n",
        "  pred_label, pred_prob = top_k_predictions(img_tensor)\n",
        "  for label, prob in zip(pred_label, pred_prob):\n",
        "    print(f'{label}: {prob:0.1%}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v-lH1N4timM2"
      },
      "source": [
        "## Calculate Integrated Gradients"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FV7pfZHANaz1"
      },
      "source": [
        "Your model, Inception V1, is a learned function that describes a mapping between your input feature space, image pixel values, and an output space defined by ImageNet class probability values between 0 and 1. Early interpretability methods for neural networks assigned feature importance scores using gradients, which tell you which pixels have the steepest local slope relative to your model's prediction at a given point along your model's prediction function. However, gradients only describe *local* changes in your model's prediction function with respect to pixel values and do not fully describe your entire model prediction function. As your model fully \"learns\" the relationship between the range of an individual pixel and the correct ImageNet class, the gradient for this pixel will *saturate*, meaning it becomes increasingly small and can even go to zero. Consider the simple model function below:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0AUkIUvpkaO8"
      },
      "outputs": [],
      "source": [
        "def f(x):\n",
        "  \"\"\"A simplified model function.\"\"\"\n",
        "  return tf.where(x < 0.8, x, 0.8)\n",
        "\n",
        "def interpolated_path(x):\n",
        "  \"\"\"A straight line path.\"\"\"\n",
        "  return tf.zeros_like(x)\n",
        "\n",
        "x = tf.linspace(start=0.0, stop=1.0, num=6)\n",
        "y = f(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kMdAKooulVRE"
      },
      "outputs": [],
      "source": [
        "fig = plt.figure(figsize=(12, 5))\n",
        "ax0 = fig.add_subplot(121)\n",
        "ax0.plot(x, f(x), marker='o')\n",
        "ax0.set_title('Gradients saturate over F(x)', fontweight='bold')\n",
        "ax0.text(0.2, 0.5, 'Gradients > 0 = \\n x is important')\n",
        "ax0.text(0.7, 0.85, 'Gradients = 0 \\n x not important')\n",
        "ax0.set_yticks(tf.range(0, 1.5, 0.5))\n",
        "ax0.set_xticks(tf.range(0, 1.5, 0.5))\n",
        "ax0.set_ylabel('F(x) - model true class predicted probability')\n",
        "ax0.set_xlabel('x - (pixel value)')\n",
        "\n",
        "ax1 = fig.add_subplot(122)\n",
        "ax1.plot(x, f(x), marker='o')\n",
        "ax1.plot(x, interpolated_path(x), marker='>')\n",
        "ax1.set_title('IG intuition', fontweight='bold')\n",
        "ax1.text(0.25, 0.1, 'Accumulate gradients along path')\n",
        "ax1.set_ylabel('F(x) - model true class predicted probability')\n",
        "ax1.set_xlabel('x - (pixel value)')\n",
        "ax1.set_yticks(tf.range(0, 1.5, 0.5))\n",
        "ax1.set_xticks(tf.range(0, 1.5, 0.5))\n",
        "ax1.annotate('Baseline', xy=(0.0, 0.0), xytext=(0.0, 0.2),\n",
        "             arrowprops=dict(facecolor='black', shrink=0.1))\n",
        "ax1.annotate('Input', xy=(1.0, 0.0), xytext=(0.95, 0.2),\n",
        "             arrowprops=dict(facecolor='black', shrink=0.1))\n",
        "plt.show();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LswhGFd0xZBr"
      },
      "source": [
        "* **left**: Your model's gradients for pixel `x` are positive between 0.0 and 0.8 but go to 0.0 between 0.8 and 1.0. Pixel `x` clearly has a significant impact on pushing your model toward 80% predicted probability on the true class. *Does it make sense that pixel `x`'s importance is small or discontinuous?*\n",
        "\n",
        "* **right**: The intuition behind IG is to accumulate pixel `x`'s local gradients and attribute its importance as a score for how much it adds to or subtracts from your model's overall output class probability. You can break down and compute IG in 3 parts: \n",
        "  1. interpolate small steps along a straight line in the feature space between 0 (a baseline or starting point) and 1 (input pixel's value) \n",
        "  2. compute the gradients of your model's predictions with respect to each step \n",
        "  3. approximate the integral between your baseline and input by accumulating (cumulative average) these local gradients.\n",
        "\n",
        "To reinforce this intuition, you will walk through these 3 parts by applying IG to the example \"Fireboat\" image below. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "28MU35BCLM-s"
      },
      "source": [
        "### Establish a baseline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MIPG5yYfkydQ"
      },
      "source": [
        "A baseline is an input image used as a starting point for calculating feature importance. Intuitively, you can think of the baseline's explanatory role as representing the impact of the absence of each pixel on the \"Fireboat\" prediction, to contrast with the impact of each pixel on the \"Fireboat\" prediction when it is present in the input image. As a result, the choice of the baseline plays a central role in interpreting and visualizing pixel feature importances. For additional discussion of baseline selection, see the resources in the \"Next steps\" section at the bottom of this tutorial. Here, you will use a black image whose pixel values are all zero.\n",
        "\n",
        "Other choices you could experiment with include an all white image, or a random image, which you can create with `tf.random.uniform(shape=(224,224,3), minval=0.0, maxval=1.0)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wxvpwGkj4G4J"
      },
      "outputs": [],
      "source": [
        "baseline = tf.zeros(shape=(224,224,3))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vXRYwBWQS19B"
      },
      "outputs": [],
      "source": [
        "plt.imshow(baseline)\n",
        "plt.title(\"Baseline\")\n",
        "plt.axis('off')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xphryu2mGAk8"
      },
      "source": [
        "### Unpack formulas into code"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QstMR0IcbfFA"
      },
      "source": [
        "The formula for Integrated Gradients is as follows:\n",
        "\n",
        "$IntegratedGradients_{i}(x) ::= (x_{i} - x'_{i})\\times\\int_{\\alpha=0}^1\\frac{\\partial F(x'+\\alpha \\times (x - x'))}{\\partial x_i}{d\\alpha}$\n",
        "\n",
        "where:\n",
        "\n",
        "$_{i}$ = feature   \n",
        "$x$ = input  \n",
        "$x'$ = baseline   \n",
        "$\\alpha$ = interpolation constant to perturb features by"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_7w8SD5YMqvi"
      },
      "source": [
        "In practice, computing a definite integral is not always numerically possible and can be computationally costly, so you compute the following numerical approximation:\n",
        "\n",
        "$IntegratedGrads^{approx}_{i}(x)::=(x_{i}-x'_{i})\\times\\sum_{k=1}^{m}\\frac{\\partial F(x' + \\frac{k}{m}\\times(x - x'))}{\\partial x_{i}} \\times \\frac{1}{m}$\n",
        "\n",
        "where:\n",
        "\n",
        "$_{i}$ = feature (individual pixel)  \n",
        "$x$ = input (image tensor)  \n",
        "$x'$ = baseline (image tensor)  \n",
        "$k$ = scaled feature perturbation constant  \n",
        "$m$ = number of steps in the Riemann sum approximation of the integral  \n",
        "$(x_{i}-x'_{i})$ = a term for the difference from the baseline. This is necessary to scale the integrated gradients and keep them in terms of the original image. The path from the baseline image to the input is in pixel space. Since with IG you are integrating in a straight line (linear transformation) this ends up being roughly equivalent to the integral term of the derivative of the interpolated image function with respect to $\\alpha$ with enough steps. The integral sums each pixel's gradient times the change in the pixel along the path. It's simpler to implement this integration as uniform steps from one image to the other, substituting $x := (x' + \\alpha(x-x'))$. So the change of variables gives $dx = (x-x')d\\alpha$. The $(x-x')$ term is constant and is factored out of the integral."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5aPG68RssS2h"
      },
      "source": [
        "### Interpolate images"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pPrldEYsIR4M"
      },
      "source": [
        "$IntegratedGrads^{approx}_{i}(x)::=(x_{i}-x'_{i})\\times\\sum_{k=1}^{m}\\frac{\\partial F(\\overbrace{x' + \\frac{k}{m}\\times(x - x')}^\\text{interpolate m images at k intervals})}{\\partial x_{i}} \\times \\frac{1}{m}$"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y4r-ZrIIsdbI"
      },
      "source": [
        "First, you will generate a [linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation) between the baseline and the original image. You can think of interpolated images as small steps in the feature space between your baseline and input, represented by $\\alpha$ in the original equation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I42mBKXyjcIc"
      },
      "outputs": [],
      "source": [
        "m_steps=50\n",
        "alphas = tf.linspace(start=0.0, stop=1.0, num=m_steps+1) # Generate m_steps intervals for integral_approximation() below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7SWLSFOHsbgh"
      },
      "outputs": [],
      "source": [
        "def interpolate_images(baseline,\n",
        "                       image,\n",
        "                       alphas):\n",
        "  alphas_x = alphas[:, tf.newaxis, tf.newaxis, tf.newaxis]\n",
        "  baseline_x = tf.expand_dims(baseline, axis=0)\n",
        "  input_x = tf.expand_dims(image, axis=0)\n",
        "  delta = input_x - baseline_x\n",
        "  images = baseline_x +  alphas_x * delta\n",
        "  return images"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s4zFzbUBj684"
      },
      "source": [
        "Let's use the above function to generate interpolated images along a linear path at alpha intervals between a black baseline image and the example \"Fireboat\" image."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NgVx8swDQtTl"
      },
      "outputs": [],
      "source": [
        "interpolated_images = interpolate_images(\n",
        "    baseline=baseline,\n",
        "    image=img_name_tensors['Fireboat'],\n",
        "    alphas=alphas)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QABFsuCvkO1h"
      },
      "source": [
        "Let's visualize the interpolated images. Note: another way of thinking about the $\\alpha$ constant is that it is consistently increasing each interpolated image's intensity."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9tmBGdnHAupk"
      },
      "outputs": [],
      "source": [
        "fig = plt.figure(figsize=(20, 20))\n",
        "\n",
        "i = 0\n",
        "for alpha, image in zip(alphas[0::10], interpolated_images[0::10]):\n",
        "  i += 1\n",
        "  plt.subplot(1, len(alphas[0::10]), i)\n",
        "  plt.title(f'alpha: {alpha:.1f}')\n",
        "  plt.imshow(image)\n",
        "  plt.axis('off')\n",
        "\n",
        "plt.tight_layout();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h7T0f1cqsaxA"
      },
      "source": [
        "### Compute gradients"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tps0eWc0REqL"
      },
      "source": [
        "Now let's take a look at how to calculate gradients in order to measure the relationship between changes to a feature and changes in the model's predictions. In the case of images, the gradient tells us which pixels have the strongest effect on the model's predicted class probabilities."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ouuVIsdfgukW"
      },
      "source": [
        "$IntegratedGrads^{approx}_{i}(x)::=(x_{i}-x'_{i})\\times\\sum_{k=1}^{m}\\frac{\\overbrace{\\partial F(\\text{interpolated images})}^\\text{compute gradients}}{\\partial x_{i}} \\times \\frac{1}{m}$\n",
        "\n",
        "where:  \n",
        "$F()$ = your model's prediction function  \n",
        "$\\frac{\\partial{F}}{\\partial{x_i}}$ = gradient (vector of partial derivatives $\\partial$) of your model F's prediction function relative to each feature $x_i$"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hY_Ok3CoJW1W"
      },
      "source": [
        "TensorFlow makes computing gradients easy for you with a [`tf.GradientTape`](https://www.tensorflow.org/api_docs/python/tf/GradientTape)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JW1O9qEsxZOP"
      },
      "outputs": [],
      "source": [
        "def compute_gradients(images, target_class_idx):\n",
        "  with tf.GradientTape() as tape:\n",
        "    tape.watch(images)\n",
        "    logits = model(images)\n",
        "    probs = tf.nn.softmax(logits, axis=-1)[:, target_class_idx]\n",
        "  return tape.gradient(probs, images)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9BfRuzx4-c87"
      },
      "source": [
        "Let's compute the gradients for each image along the interpolation path with respect to the correct output. Recall that your model returns a `(1, 1001)` shaped `Tensor` with logits that you convert to predicted probabilities for each class. You need to pass the correct ImageNet target class index to the `compute_gradients` function for your image."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kHIR58rNJ3q_"
      },
      "outputs": [],
      "source": [
        "path_gradients = compute_gradients(\n",
        "    images=interpolated_images,\n",
        "    target_class_idx=555)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JJodSbDUQ3T_"
      },
      "source": [
        "Note the output shape of `(n_interpolated_images, img_height, img_width, RGB)`, which gives us the gradient for every pixel of every image along the interpolation path. You can think of these gradients as measuring the change in your model's predictions for each small step in the feature space."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v2rpO2JTbQId"
      },
      "outputs": [],
      "source": [
        "print(path_gradients.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zmaPpr5bUtbr"
      },
      "source": [
        "**Visualizing gradient saturation**\n",
        "\n",
        "Recall that the gradients you just calculated above describe *local* changes to your model's predicted probability of \"Fireboat\" and can *saturate*.\n",
        "\n",
        "These concepts are visualized using the gradients you calculated above in the 2 plots below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mCH8sAf3TTJ2"
      },
      "outputs": [],
      "source": [
        "pred = model(interpolated_images)\n",
        "pred_proba = tf.nn.softmax(pred, axis=-1)[:, 555]\n",
        "\n",
        "plt.figure(figsize=(10, 4))\n",
        "ax1 = plt.subplot(1, 2, 1)\n",
        "ax1.plot(alphas, pred_proba)\n",
        "ax1.set_title('Target class predicted probability over alpha')\n",
        "ax1.set_ylabel('model p(target class)')\n",
        "ax1.set_xlabel('alpha')\n",
        "ax1.set_ylim([0, 1])\n",
        "\n",
        "ax2 = plt.subplot(1, 2, 2)\n",
        "# Average gradients over all pixels (height, width, channels) at each interpolation step\n",
        "average_grads = tf.reduce_mean(path_gradients, axis=[1, 2, 3])\n",
        "# Normalize gradients to 0 to 1 scale. E.g. (x - min(x))/(max(x)-min(x))\n",
        "average_grads_norm = (average_grads-tf.math.reduce_min(average_grads))/(tf.math.reduce_max(average_grads)-tf.reduce_min(average_grads))\n",
        "ax2.plot(alphas, average_grads_norm)\n",
        "ax2.set_title('Average pixel gradients (normalized) over alpha')\n",
        "ax2.set_ylabel('Average pixel gradients')\n",
        "ax2.set_xlabel('alpha')\n",
        "ax2.set_ylim([0, 1]);"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-ntMpA87jNN6"
      },
      "source": [
        "* **left**: This plot shows how your model's confidence in the \"Fireboat\" class varies across alphas. Notice how the gradient, or slope of the line, largely flattens or saturates between 0.6 and 1.0 before settling at the final \"Fireboat\" predicted probability of about 40%.\n",
        "\n",
        "* **right**: The right plot shows the average gradient magnitudes over alpha more directly. Note how the values sharply approach and even briefly dip below zero. In fact, your model \"learns\" the most from gradients at lower values of alpha before saturating. Intuitively, you can think of this as your model having learned the pixels (e.g. water cannons) it needs to make the correct prediction, sending these pixels' gradients to zero, while remaining quite uncertain and focused on spurious bridge or water jet pixels as the alpha values approach the original input image.\n",
        "\n",
        "To make sure these important water cannon pixels are reflected as important to the \"Fireboat\" prediction, you will continue on below to learn how to accumulate these gradients to accurately approximate how each pixel impacts your \"Fireboat\" predicted probability.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LQdACCM6sJdW"
      },
      "source": [
        "### Accumulate gradients (integral approximation)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QHopk9evmO5P"
      },
      "source": [
        "There are many different ways you can go about computing the numerical approximation of an integral for IG with different tradeoffs in accuracy and convergence across varying functions. A popular class of methods is called [Riemann sums](https://en.wikipedia.org/wiki/Riemann_sum). Here, you will use the Trapezoidal rule (you can find additional code to explore different approximation methods at the end of this tutorial)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GshPZQgROs80"
      },
      "source": [
        "$IntegratedGrads^{approx}_{i}(x)::=(x_{i}-x'_{i})\\times \\overbrace{\\sum_{k=1}^{m}}^\\text{Sum m local gradients}\n",
        "\\text{gradients(interpolated images)} \\times \\overbrace{\\frac{1}{m}}^\\text{Divide by m steps}$\n",
        "\n",
        "From the equation, you can see you are summing over `m` gradients and dividing by `m` steps. You can implement the two operations together for part 3 as an *average of the local gradients of `m` interpolated predictions and input images*."
      ]
    },
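    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1cMVl-Grx3lp"
      },
      "outputs": [],
      "source": [
        "def integral_approximation(gradients):\n",
        "  # Trapezoidal Riemann sum: average each pair of neighboring gradients\n",
        "  # along the path, then take the mean over all steps.\n",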
        "  grads = (gradients[:-1] + gradients[1:]) / tf.constant(2.0)\n",
        "  integrated_gradients = tf.math.reduce_mean(grads, axis=0)\n",
        "  return integrated_gradients"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QVQAHunkW79t"
      },
      "source": [
        "The `integral_approximation` function takes the gradients of the predicted probability of the target class with respect to the interpolated images between the baseline and the original image."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JeF01fydNq0I"
      },
      "outputs": [],
      "source": [
        "ig = integral_approximation(\n",
        "    gradients=path_gradients)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7XVItLpurOAM"
      },
      "source": [
        "You can confirm that averaging across the gradients of `m` interpolated images returns an integrated gradients tensor with the same shape as the original \"Fireboat\" image."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z1bP6l3ahfyn"
      },
      "outputs": [],
      "source": [
        "print(ig.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C1ODXUevyGxL"
      },
      "source": [
        "### Putting it all together"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NcaTR-x8v1At"
      },
      "source": [
        "Now you will combine the three previous parts into an `integrated_gradients` function and use a [@tf.function](https://www.tensorflow.org/guide/function) decorator to compile it into a high-performance callable TensorFlow graph. This is implemented as five smaller steps below:\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5YdWHscoovhk"
      },
      "source": [
        "$IntegratedGrads^{approx}_{i}(x)::=\\overbrace{(x_{i}-x'_{i})}^\\text{5.}\\times \\overbrace{\\sum_{k=1}^{m}}^\\text{4.} \\frac{\\partial \\overbrace{F(\\overbrace{x' + \\overbrace{\\frac{k}{m}}^\\text{1.}\\times(x - x'))}^\\text{2.}}^\\text{3.}}{\\partial x_{i}} \\times \\overbrace{\\frac{1}{m}}^\\text{4.}$"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pjdfjp_3pHiY"
      },
      "source": [
        "1. Generate alphas $\\alpha$\n",
        "\n",
        "2. Generate interpolated images = $(x' + \\frac{k}{m}\\times(x - x'))$\n",
        "\n",
        "3. Compute gradients of the model $F$ output predictions with respect to input features = $\\frac{\\partial F(\\text{interpolated path inputs})}{\\partial x_{i}}$\n",
        "\n",
        "4. Integral approximation through averaging gradients = $\\sum_{k=1}^m \\text{gradients} \\times \\frac{1}{m}$\n",
        "\n",
        "5. Scale integrated gradients with respect to original image = $(x_{i}-x'_{i}) \\times \\text{integrated gradients}$. This step is necessary to make sure that the attribution values accumulated across multiple interpolated images are in the same units and faithfully represent the pixel importances on the original image."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O_H3k9Eu7Rl5"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def integrated_gradients(baseline,\n",
        "                         image,\n",
        "                         target_class_idx,\n",
        "                         m_steps=50,\n",
        "                         batch_size=32):\n",
        "  # 1. Generate alphas.\n",
        "  alphas = tf.linspace(start=0.0, stop=1.0, num=m_steps+1)\n",
        "\n",
        "  # Initialize TensorArray outside loop to collect gradients.    \n",
        "  gradient_batches = tf.TensorArray(tf.float32, size=m_steps+1)\n",
        "    \n",
        "  # Iterate alphas range and batch computation for speed, memory efficiency, and scaling to larger m_steps.\n",
        "  for alpha in tf.range(0, len(alphas), batch_size):\n",
        "    from_ = alpha\n",
        "    to = tf.minimum(from_ + batch_size, len(alphas))\n",
        "    alpha_batch = alphas[from_:to]\n",
        "\n",
        "    # 2. Generate interpolated inputs between baseline and input.\n",
        "    interpolated_path_input_batch = interpolate_images(baseline=baseline,\n",
        "                                                       image=image,\n",
        "                                                       alphas=alpha_batch)\n",
        "\n",
        "    # 3. Compute gradients between model outputs and interpolated inputs.\n",
        "    gradient_batch = compute_gradients(images=interpolated_path_input_batch,\n",
        "                                       target_class_idx=target_class_idx)\n",
        "    \n",
        "    # Write batch indices and gradients to extend TensorArray.\n",
        "    gradient_batches = gradient_batches.scatter(tf.range(from_, to), gradient_batch)    \n",
        "  \n",
        "  # Stack path gradients together row-wise into single tensor.\n",
        "  total_gradients = gradient_batches.stack()\n",
        "\n",
        "  # 4. Integral approximation through averaging gradients.\n",
        "  avg_gradients = integral_approximation(gradients=total_gradients)\n",
        "\n",
        "  # 5. Scale integrated gradients with respect to input.\n",
        "  integrated_gradients = (image - baseline) * avg_gradients\n",
        "\n",
        "  return integrated_gradients"
      ]
    },
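    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see how the loop above partitions the `m_steps + 1` alphas, the plain-Python sketch below lists the index range covered by each batch. `batch_ranges` is a hypothetical helper written only for this illustration; it is not used by `integrated_gradients` and does not call TensorFlow. Note that the last batch can be smaller than `batch_size`, which is why the function clips the upper index with `tf.minimum`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def batch_ranges(num_alphas, batch_size):\n",
        "  # Hypothetical helper: list the [start, stop) index pairs visited by the loop.\n",
        "  return [(start, min(start + batch_size, num_alphas))\n",
        "          for start in range(0, num_alphas, batch_size)]\n",
        "\n",
        "# m_steps=50 gives 51 alphas; m_steps=240 gives 241 alphas.\n",
        "print(batch_ranges(num_alphas=50 + 1, batch_size=32))\n",
        "print(batch_ranges(num_alphas=240 + 1, batch_size=32))"
      ]
    },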
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8G0ELl_wRrd0"
      },
      "outputs": [],
      "source": [
        "ig_attributions = integrated_gradients(baseline=baseline,\n",
        "                                       image=img_name_tensors['Fireboat'],\n",
        "                                       target_class_idx=555,\n",
        "                                       m_steps=240)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P-LSHD2sajFf"
      },
      "source": [
        "Again, you can check that the IG feature attributions have the same shape as the input \"Fireboat\" image."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "beoEunC5aZOa"
      },
      "outputs": [],
      "source": [
        "print(ig_attributions.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fLaB21txJhMq"
      },
      "source": [
        "The paper suggests that the number of steps range between 20 and 300 depending upon the example (although in practice this can be higher, in the 1,000s, to accurately approximate the integral). You can find additional code to check for the appropriate number of steps in the \"Next steps\" resources at the end of this tutorial."
      ]
    },
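    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One common way to judge whether the number of steps is large enough relies on the completeness property described in the paper: the attributions should sum to the difference between the model output at the input and at the baseline. The cell below is a minimal sketch of that check on a toy differentiable function rather than on the image model used in this tutorial. `toy_model`, `toy_gradient`, `toy_integrated_gradients`, and the weights and inputs are all hypothetical and exist only for this illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "TOY_WEIGHTS = [2.0, -1.0, 0.5]\n",
        "\n",
        "def toy_model(x):\n",
        "  # Sigmoid of a weighted sum; stands in for a predicted class probability.\n",
        "  z = sum(w * v for w, v in zip(TOY_WEIGHTS, x))\n",
        "  return 1.0 / (1.0 + math.exp(-z))\n",
        "\n",
        "def toy_gradient(x):\n",
        "  # Analytic gradient of the sigmoid output with respect to each input.\n",
        "  p = toy_model(x)\n",
        "  return [p * (1.0 - p) * w for w in TOY_WEIGHTS]\n",
        "\n",
        "def toy_integrated_gradients(baseline, x, m_steps):\n",
        "  # Gradients at m_steps + 1 points on the straight line from baseline to x.\n",
        "  grads = []\n",
        "  for k in range(m_steps + 1):\n",
        "    alpha = k / m_steps\n",
        "    point = [b + alpha * (v - b) for b, v in zip(baseline, x)]\n",
        "    grads.append(toy_gradient(point))\n",
        "  # Trapezoidal rule: average neighboring gradients, then average over steps.\n",
        "  avg = [0.0] * len(x)\n",
        "  for g0, g1 in zip(grads[:-1], grads[1:]):\n",
        "    for i in range(len(x)):\n",
        "      avg[i] += (g0[i] + g1[i]) / 2.0 / m_steps\n",
        "  # Scale by the difference between the input and the baseline.\n",
        "  return [(v - b) * a for v, b, a in zip(x, baseline, avg)]\n",
        "\n",
        "toy_baseline = [0.0, 0.0, 0.0]\n",
        "toy_input = [3.0, 1.0, 2.0]\n",
        "expected_delta = toy_model(toy_input) - toy_model(toy_baseline)\n",
        "\n",
        "# The gap between the attribution sum and the output difference should shrink.\n",
        "for n_steps in (5, 20, 80):\n",
        "  toy_attributions = toy_integrated_gradients(toy_baseline, toy_input, n_steps)\n",
        "  print(n_steps, abs(sum(toy_attributions) - expected_delta))"
      ]
    },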
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o55W6NYXGSZ8"
      },
      "source": [
        "### Visualize attributions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XSQ6Y-DZvrQu"
      },
      "source": [
        "You are ready to visualize attributions and overlay them on the original image. The code below sums the absolute values of the integrated gradients across the color channels to produce an attribution mask. This plotting method captures the relative impact of pixels on the model's predictions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4QN2cEA_WFym"
      },
      "outputs": [],
      "source": [
        "def plot_img_attributions(baseline,\n",
        "                          image,\n",
        "                          target_class_idx,\n",
        "                          m_steps=50,\n",
        "                          cmap=None,\n",
        "                          overlay_alpha=0.4):\n",
        "\n",
        "  attributions = integrated_gradients(baseline=baseline,\n",
        "                                      image=image,\n",
        "                                      target_class_idx=target_class_idx,\n",
        "                                      m_steps=m_steps)\n",
        "\n",
        "  # Sum the absolute attribution values across color channels for visualization.\n",
        "  # The resulting attribution mask is a grayscale image with the same height\n",
        "  # and width as the original image.\n",
        "  attribution_mask = tf.reduce_sum(tf.math.abs(attributions), axis=-1)\n",
        "\n",
        "  fig, axs = plt.subplots(nrows=2, ncols=2, squeeze=False, figsize=(8, 8))\n",
        "\n",
        "  axs[0, 0].set_title('Baseline image')\n",
        "  axs[0, 0].imshow(baseline)\n",
        "  axs[0, 0].axis('off')\n",
        "\n",
        "  axs[0, 1].set_title('Original image')\n",
        "  axs[0, 1].imshow(image)\n",
        "  axs[0, 1].axis('off')\n",
        "\n",
        "  axs[1, 0].set_title('Attribution mask')\n",
        "  axs[1, 0].imshow(attribution_mask, cmap=cmap)\n",
        "  axs[1, 0].axis('off')\n",
        "\n",
        "  axs[1, 1].set_title('Overlay')\n",
        "  axs[1, 1].imshow(attribution_mask, cmap=cmap)\n",
        "  axs[1, 1].imshow(image, alpha=overlay_alpha)\n",
        "  axs[1, 1].axis('off')\n",
        "\n",
        "  plt.tight_layout()\n",
        "  return fig"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n73VxzbxeMvD"
      },
      "source": [
        "Looking at the attributions on the \"Fireboat\" image, you can see the model identifies the water cannons and spouts as contributing to its correct prediction."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vxCQFx96iDVs"
      },
      "outputs": [],
      "source": [
        "_ = plot_img_attributions(image=img_name_tensors['Fireboat'],\n",
        "                          baseline=baseline,\n",
        "                          target_class_idx=555,\n",
        "                          m_steps=240,\n",
        "                          cmap=plt.cm.inferno,\n",
        "                          overlay_alpha=0.4)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lo4SncDZfTw0"
      },
      "source": [
        "On the \"Giant Panda\" image, the attributions highlight the texture, nose, and fur of the panda's face."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TcpGLJWuHnYl"
      },
      "outputs": [],
      "source": [
        "_ = plot_img_attributions(image=img_name_tensors['Giant Panda'],\n",
        "                          baseline=baseline,\n",
        "                          target_class_idx=389,\n",
        "                          m_steps=55,\n",
        "                          cmap=plt.cm.viridis,\n",
        "                          overlay_alpha=0.5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H3etJZHuI6hX"
      },
      "source": [
        "## Uses and limitations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xEX4Jh2uLvxA"
      },
      "source": [
        "Use cases\n",
        "* Employing techniques like Integrated Gradients before deploying your model can help you develop intuition for how and why it works. Do the features highlighted by this technique match your intuition? If not, that may indicate a bug in your model or dataset, or overfitting.\n",
        "\n",
        "Limitations\n",
        "* Integrated Gradients provides feature importances on individual examples; however, it does not provide global feature importances across an entire dataset.\n",
        "\n",
        "* Integrated Gradients provides individual feature importances, but it does not explain feature interactions and combinations."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ejc2Ho_8i162"
      },
      "source": [
        "## Next steps"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G8Ra5ijj7pEc"
      },
      "source": [
        "This tutorial presented a basic implementation of Integrated Gradients. As a next step, you can use this notebook to try this technique with different models and images yourself.\n",
        "\n",
        "For interested readers, there is a lengthier version of this tutorial (which includes code for different baselines, for computing integral approximations, and for determining a sufficient number of steps), which you can find [here](https://github.com/GoogleCloudPlatform/training-data-analyst/tree/master/blogs/integrated_gradients).\n",
        "\n",
        "To deepen your understanding, check out the paper [Axiomatic Attribution for Deep Networks](https://arxiv.org/abs/1703.01365) and its [GitHub repository](https://github.com/ankurtaly/Integrated-Gradients), which contains an implementation in a previous version of TensorFlow. You can also explore feature attribution, and the impact of different baselines, on [distill.pub](https://distill.pub/2020/attribution-baselines/).\n",
        "\n",
        "Interested in incorporating IG into your production machine learning workflows for feature importances, model error analysis, and data skew monitoring? Check out Google Cloud's [Explainable AI](https://cloud.google.com/explainable-ai) product, which supports IG attributions. The Google AI PAIR research group also open-sourced the [What-If Tool](https://pair-code.github.io/what-if-tool/index.html#about), which can be used for model debugging, including visualizing IG feature attributions."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "integrated_gradients.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
